{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "dbed36fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "import random\n",
    "\n",
    "def select_and_copy_files(source_dir_images, source_dir_labels, dest_dir_images, dest_dir_labels, num_files, categories):\n",
    "    for category in categories:\n",
    "        source_subdir_images = os.path.join(source_dir_images, category)\n",
    "        source_subdir_labels = os.path.join(source_dir_labels, category)\n",
    "        dest_subdir_images = os.path.join(dest_dir_images, category)\n",
    "        dest_subdir_labels = os.path.join(dest_dir_labels, category)\n",
    "\n",
    "        # 创建目标子目录，如果它们不存在的话\n",
    "        os.makedirs(dest_subdir_images, exist_ok=True)\n",
    "        os.makedirs(dest_subdir_labels, exist_ok=True)\n",
    "\n",
    "        # 获取源目录中的所有.jpg和.txt文件\n",
    "        all_images = [f[:-4] for f in os.listdir(source_subdir_images) if os.path.isfile(os.path.join(source_subdir_images, f)) and f.endswith('.jpg')]\n",
    "        all_labels = [f[:-4] for f in os.listdir(source_subdir_labels) if os.path.isfile(os.path.join(source_subdir_labels, f)) and f.endswith('.txt')]\n",
    "\n",
    "        # 找到同时存在于两个目录中的文件\n",
    "        common_files = list(set(all_images) & set(all_labels))\n",
    "\n",
    "        if len(common_files) < num_files[category]:\n",
    "            raise ValueError(f\"Only {len(common_files)} common files found in {category}. Cannot select {num_files[category]} files.\")\n",
    "\n",
    "        selected_files = random.sample(common_files, num_files[category])\n",
    "\n",
    "        for file in selected_files:\n",
    "            src_file_image = os.path.join(source_subdir_images, file + '.jpg')\n",
    "            src_file_label = os.path.join(source_subdir_labels, file + '.txt')\n",
    "            dst_file_image = os.path.join(dest_subdir_images, file + '.jpg')\n",
    "            dst_file_label = os.path.join(dest_subdir_labels, file + '.txt')\n",
    "\n",
    "            shutil.copy(src_file_image, dst_file_image)\n",
    "            shutil.copy(src_file_label, dst_file_label)\n",
    "\n",
    "# Source roots to sample from and destination roots to copy into.\n",
    "source_directory_images = \"/data/jiayuan/BDDcoco/yolo_v8/images\"\n",
    "source_directory_labels = \"/data/jiayuan/BDDcoco/yolo_v8/labels\"\n",
    "destination_directory_images = \"/data/jiayuan/BDDcoco/yolo_v8_toy/images\"\n",
    "destination_directory_labels = \"/data/jiayuan/BDDcoco/yolo_v8_toy/labels\"\n",
    "\n",
    "# Sub-directories to draw from, and how many image/label pairs to take from each.\n",
    "num_of_files = {'train2017': 100, 'val2017': 10}\n",
    "categories = ['train2017', 'val2017']\n",
    "\n",
    "# Select and copy the matching .jpg and .txt files.\n",
    "select_and_copy_files(\n",
    "    source_dir_images=source_directory_images,\n",
    "    source_dir_labels=source_directory_labels,\n",
    "    dest_dir_images=destination_directory_images,\n",
    "    dest_dir_labels=destination_directory_labels,\n",
    "    num_files=num_of_files,\n",
    "    categories=categories,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3193eea8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:yolo8]",
   "language": "python",
   "name": "conda-env-yolo8-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
